#ifndef LIGHTGBM_UTILS_ARRAY_AGRS_H_
#define LIGHTGBM_UTILS_ARRAY_AGRS_H_

#include <vector>
#include <algorithm>
#include <LightGBM/utils/openmp_wrapper.h>

namespace LightGBM {

/*!
* \brief Static helpers for locating extreme values in an array:
*        ArgMax / ArgMin over vectors and raw buffers, plus top-k selection
*        (Partition, ArgMaxAtK, MaxK) that orders values from largest to smallest.
* \tparam VAL_T Element type. Each helper only needs the comparison operators it
*         actually uses (operator> for ArgMax, operator< for ArgMin,
*         operator> and operator== for the selection helpers).
*/
template<typename VAL_T>
class ArrayArgs {
public:
  /*!
  * \brief Multi-threaded ArgMax: the array is cut into one contiguous chunk per
  *        thread, each chunk is scanned independently, and the per-chunk winners
  *        are reduced sequentially.
  * \param array Values to scan
  * \return Index of the first maximal element; 0 if array is empty
  */
  inline static size_t ArgMaxMT(const std::vector<VAL_T>& array) {
    const size_t num_data = array.size();
    if (num_data == 0) {
      return 0;
    }
    int num_threads = 1;
    // Without OpenMP the scan below runs on a single thread, so there is
    // nothing to query and num_threads stays at 1.
#ifdef _OPENMP
#pragma omp parallel
#pragma omp master
    {
      num_threads = omp_get_num_threads();
    }
#endif
    num_threads = std::max(1, num_threads);
    const size_t num_chunks = static_cast<size_t>(num_threads);
    // Ceiling division in size_t: no narrowing of array.size() to int.
    // num_data >= 1 here, so step >= 1.
    const size_t step = (num_data + num_chunks - 1) / num_chunks;
    // Chunks that end up empty keep index 0, which is always a valid candidate.
    std::vector<size_t> arg_maxs(num_chunks, 0);
#pragma omp parallel for schedule(static, 1)
    for (int i = 0; i < num_threads; ++i) {
      const size_t start = step * static_cast<size_t>(i);
      if (start >= num_data) { continue; }
      const size_t end = std::min(num_data, start + step);
      size_t arg_max = start;
      for (size_t j = start + 1; j < end; ++j) {
        if (array[j] > array[arg_max]) {
          arg_max = j;
        }
      }
      arg_maxs[static_cast<size_t>(i)] = arg_max;
    }
    // Chunks are visited in index order and replaced only on strict '>',
    // so ties resolve to the smallest index.
    size_t ret = arg_maxs[0];
    for (size_t i = 1; i < num_chunks; ++i) {
      if (array[arg_maxs[i]] > array[ret]) {
        ret = arg_maxs[i];
      }
    }
    return ret;
  }

  /*!
  * \brief Index of the first maximal element of a vector.
  *        Arrays with more than 100 elements are delegated to ArgMaxMT.
  * \param array Values to scan
  * \return Index of the first maximal element; 0 if array is empty
  */
  inline static size_t ArgMax(const std::vector<VAL_T>& array) {
    if (array.empty()) {
      return 0;
    }
    if (array.size() > 100) {
      return ArgMaxMT(array);
    }
    size_t arg_max = 0;
    for (size_t i = 1; i < array.size(); ++i) {
      if (array[i] > array[arg_max]) {
        arg_max = i;
      }
    }
    return arg_max;
  }

  /*!
  * \brief Index of the first minimal element of a vector.
  * \param array Values to scan
  * \return Index of the first minimal element; 0 if array is empty
  */
  inline static size_t ArgMin(const std::vector<VAL_T>& array) {
    if (array.empty()) {
      return 0;
    }
    size_t arg_min = 0;
    for (size_t i = 1; i < array.size(); ++i) {
      if (array[i] < array[arg_min]) {
        arg_min = i;
      }
    }
    return arg_min;
  }

  /*!
  * \brief Index of the first maximal element of a raw buffer.
  * \param array Pointer to the first element; read only when n > 0
  * \param n Number of elements
  * \return Index of the first maximal element; 0 if n is 0
  */
  inline static size_t ArgMax(const VAL_T* array, size_t n) {
    if (n == 0) {
      return 0;
    }
    size_t arg_max = 0;
    for (size_t i = 1; i < n; ++i) {
      if (array[i] > array[arg_max]) {
        arg_max = i;
      }
    }
    return arg_max;
  }

  /*!
  * \brief Index of the first minimal element of a raw buffer.
  * \param array Pointer to the first element; read only when n > 0
  * \param n Number of elements
  * \return Index of the first minimal element; 0 if n is 0
  */
  inline static size_t ArgMin(const VAL_T* array, size_t n) {
    if (n == 0) {
      return 0;
    }
    size_t arg_min = 0;
    for (size_t i = 1; i < n; ++i) {
      if (array[i] < array[arg_min]) {
        arg_min = i;
      }
    }
    return arg_min;
  }

  /*!
  * \brief Three-way partition of (*arr)[start, end) in descending order around
  *        the pivot v = (*arr)[end - 1].
  *        On return: [start, *l] holds values > v, (*l, *r) holds values == v,
  *        and [*r, end) holds values that are not greater than v.
  *        For a range with fewer than two elements nothing is moved and
  *        *l = start - 1, *r = end (the whole range counts as the pivot band).
  * \param arr Vector to reorder in place
  * \param start First index of the range (inclusive)
  * \param end Last index of the range (exclusive)
  * \param l Output: last index of the "greater than pivot" part
  * \param r Output: first index of the "not greater than pivot" part
  */
  inline static void Partition(std::vector<VAL_T>* arr, int start, int end, int* l, int* r) {
    // Ranges with fewer than two elements: scanning would read ref[start - 1],
    // so report the trivial partition instead.
    if (start >= end - 1) {
      *l = start - 1;
      *r = end;
      return;
    }
    int i = start - 1;
    int j = end - 1;
    int p = i;  // values equal to the pivot are parked in [start, p] ...
    int q = j;  // ... and in [q, end - 2] while scanning
    std::vector<VAL_T>& ref = *arr;
    VAL_T v = ref[end - 1];
    for (;;) {
      // The pivot at end - 1 stops this scan; the j == start check stops the other.
      while (ref[++i] > v) {}
      while (v > ref[--j]) { if (j == start) { break; } }
      if (i >= j) { break; }
      std::swap(ref[i], ref[j]);
      if (ref[i] == v) { p++; std::swap(ref[p], ref[i]); }
      if (v == ref[j]) { q--; std::swap(ref[j], ref[q]); }
    }
    std::swap(ref[i], ref[end - 1]);
    j = i - 1;
    i = i + 1;
    // Move the parked pivot-equal values from both ends next to the pivot.
    for (int k = start; k <= p; k++, j--) { std::swap(ref[k], ref[j]); }
    for (int k = end - 2; k >= q; k--, i++) { std::swap(ref[i], ref[k]); }
    *l = j;
    *r = i;
  }

  /*!
  * \brief Quickselect in descending order on (*arr)[start, end).
  *        Reorders the range so that (*arr)[k] holds the value a full descending
  *        sort would put there, with nothing smaller before it and nothing
  *        larger after it.
  * \param arr Vector to reorder in place
  * \param start First index of the range (inclusive)
  * \param end Last index of the range (exclusive)
  * \param k Zero-based index into *arr (not relative to start), e.g. k = 0
  *        selects the maximum; expected to lie in [start, end)
  * \return k, or start when the range holds fewer than two elements
  */
  inline static int ArgMaxAtK(std::vector<VAL_T>* arr, int start, int end, int k) {
    if (start >= end - 1) {
      return start;
    }
    int l = start;
    int r = end - 1;
    Partition(arr, start, end, &l, &r);
    if (k > l && k < r) {
      // k landed in the band of pivot-equal values: it is already in place.
      return k;
    } else if (k <= l) {
      // end is exclusive, so l + 1 keeps index l inside the sub-range.
      return ArgMaxAtK(arr, start, l + 1, k);
    } else {
      return ArgMaxAtK(arr, r, end, k);
    }
  }

  /*!
  * \brief Collect the k largest values of array into out (in no particular order).
  *        If k >= array.size(), out becomes a plain copy of array.
  * \param array Input values (left untouched)
  * \param k Number of values to keep; out is left empty when k <= 0
  * \param out Receives the result; previous content is discarded
  */
  inline static void MaxK(const std::vector<VAL_T>& array, int k, std::vector<VAL_T>* out) {
    out->clear();
    if (k <= 0) {
      return;
    }
    for (auto val : array) {
      out->push_back(val);
    }
    if (static_cast<size_t>(k) >= array.size()) {
      return;
    }
    // After selecting index k - 1, the first k slots hold the k largest values.
    ArgMaxAtK(out, 0, static_cast<int>(out->size()), k - 1);
    out->erase(out->begin() + k, out->end());
  }

};

}  // namespace LightGBM

#endif   // LIGHTGBM_UTILS_ARRAY_AGRS_H_

